import torch

# Demonstrate torch.chunk vs. torch.split on a random 10x4 tensor.
a = torch.rand((10, 4))
print(a)

# torch.chunk: divide dim 1 (4 columns) into 2 chunks -> two 10x2 pieces.
col_chunks = torch.chunk(a, 2, dim=1)

# torch.split with an int: pieces of 3 rows along dim 0. The last piece holds
# the remainder because 10 is not divisible by 3 (10 = 3 + 3 + 3 + 1).
row_splits = torch.split(a, 3, dim=0)

# torch.split with a list: explicit per-piece sizes along dim 0. They must sum
# to a.size(0) (1 + 3 + 6 = 10), otherwise torch raises an error.
out = torch.split(a, [1, 3, 6], dim=0)
print(out)
